from nets.gcn_net import GCNNet
from nets.gat_net import GATNet
from nets.gated_gcn_net import GatedGCNNet


def GCN(net_params):
    """Build a GCN network.

    ``net_params`` is handed unchanged to the ``GCNNet`` constructor and the
    resulting instance is returned.
    """
    network = GCNNet(net_params)
    return network


def GAT(net_params):
    """Build a GAT network.

    ``net_params`` is handed unchanged to the ``GATNet`` constructor and the
    resulting instance is returned.
    """
    network = GATNet(net_params)
    return network


def GatedGCN(net_params):
    """Build a GatedGCN network.

    ``net_params`` is handed unchanged to the ``GatedGCNNet`` constructor and
    the resulting instance is returned.
    """
    network = GatedGCNNet(net_params)
    return network


"""
可以选择任意模型
"""


def gnn_model(MODEL_NAME, net_params):
    """Instantiate the network registered under ``MODEL_NAME``.

    Args:
        MODEL_NAME: Name of the model to build; one of ``'GCN'``, ``'GAT'``
            or ``'GatedGCN'``.
        net_params: Passed unchanged to the selected factory function.

    Returns:
        Whatever the selected factory function returns.

    Raises:
        KeyError: If ``MODEL_NAME`` is not one of the registered names. The
            message lists the names that are available.
    """
    models = {
        'GCN': GCN,
        'GAT': GAT,
        'GatedGCN': GatedGCN,
    }

    # Check up front so an unknown name yields a helpful message instead of a
    # bare KeyError, and so a KeyError raised inside a factory is never
    # mistaken for a bad model name.
    if MODEL_NAME not in models:
        raise KeyError(
            "Unknown model name {!r}; available models: {}".format(
                MODEL_NAME, ', '.join(sorted(models))
            )
        )

    return models[MODEL_NAME](net_params)
